{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 搭建基于CBHG的拼音到汉字语言模型\n",
    "\n",
    "最近研究一下感觉用self-attention来对语言模型进行建模会更加不错，这里先做一个CBHG的tutorial吧\n",
    "\n",
    "- 数据处理\n",
    "- 模型搭建\n",
    "\n",
    "## 1. 数据处理\n",
    "- 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"data/zh.tsv\", 'r', encoding='utf-8') as fout:\n",
    "    data = fout.readlines()[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 100150.53it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "inputs = []\n",
    "labels = []\n",
    "for i in tqdm(range(len(data))):\n",
    "    key, pny, hanzi = data[i].split('\\t')\n",
    "    inputs.append(pny.split(' '))\n",
    "    labels.append(hanzi.strip('\\n').split(' '))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['lv4', 'shi4', 'yang2', 'chun1', 'yan1', 'jing3', 'da4', 'kuai4', 'wen2', 'zhang1', 'de', 'di3', 'se4', 'si4', 'yue4', 'de', 'lin2', 'luan2', 'geng4', 'shi4', 'lv4', 'de2', 'xian1', 'huo2', 'xiu4', 'mei4', 'shi1', 'yi4', 'ang4', 'ran2'], ['ta1', 'jin3', 'ping2', 'yao1', 'bu4', 'de', 'li4', 'liang4', 'zai4', 'yong3', 'dao4', 'shang4', 'xia4', 'fan1', 'teng2', 'yong3', 'dong4', 'she2', 'xing2', 'zhuang4', 'ru2', 'hai3', 'tun2', 'yi1', 'zhi2', 'yi3', 'yi1', 'tou2', 'de', 'you1', 'shi4', 'ling3', 'xian1'], ['pao4', 'yan3', 'da3', 'hao3', 'le', 'zha4', 'yao4', 'zen3', 'me', 'zhuang1', 'yue4', 'zheng4', 'cai2', 'yao3', 'le', 'yao3', 'ya2', 'shu1', 'de', 'tuo1', 'qu4', 'yi1', 'fu2', 'guang1', 'bang3', 'zi', 'chong1', 'jin4', 'le', 'shui3', 'cuan4', 'dong4'], ['ke3', 'shei2', 'zhi1', 'wen2', 'wan2', 'hou4', 'ta1', 'yi1', 'zhao4', 'jing4', 'zi', 'zhi3', 'jian4', 'zuo3', 'xia4', 'yan3', 'jian3', 'de', 'xian4', 'you4', 'cu1', 'you4', 'hei1', 'yu3', 'you4', 'ce4', 'ming2', 'xian3', 'bu4', 'dui4', 'cheng1'], ['qi1', 'shi2', 'nian2', 'dai4', 'mo4', 'wo3', 'wai4', 'chu1', 'qiu2', 'xue2', 'mu3', 'qin1', 'ding1', 'ning2', 'wo3', 'chi1', 'fan4', 'yao4', 'xi4', 'jue2', 'man4', 'yan4', 'xue2', 'xi2', 'yao4', 'shen1', 'zuan1', 'xi4', 'yan2']]\n",
      "\n",
      "[['绿', '是', '阳', '春', '烟', '景', '大', '块', '文', '章', '的', '底', '色', '四', '月', '的', '林', '峦', '更', '是', '绿', '得', '鲜', '活', '秀', '媚', '诗', '意', '盎', '然'], ['他', '仅', '凭', '腰', '部', '的', '力', '量', '在', '泳', '道', '上', '下', '翻', '腾', '蛹', '动', '蛇', '行', '状', '如', '海', '豚', '一', '直', '以', '一', '头', '的', '优', '势', '领', '先'], ['炮', '眼', '打', '好', '了', '炸', '药', '怎', '么', '装', '岳', '正', '才', '咬', '了', '咬', '牙', '倏', '地', '脱', '去', '衣', '服', '光', '膀', '子', '冲', '进', '了', '水', '窜', '洞'], ['可', '谁', '知', '纹', '完', '后', '她', '一', '照', '镜', '子', '只', '见', '左', '下', '眼', '睑', '的', '线', '又', '粗', '又', '黑', '与', '右', '侧', '明', '显', '不', '对', '称'], ['七', '十', '年', '代', '末', '我', '外', '出', '求', '学', '母', '亲', '叮', '咛', '我', '吃', '饭', '要', '细', '嚼', '慢', '咽', '学', '习', '要', '深', '钻', '细', '研']]\n"
     ]
    }
   ],
   "source": [
    "print(inputs[:5])\n",
    "print()\n",
    "print(labels[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 构造输入输出词典"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 9116.07it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4774.61it/s]\n"
     ]
    }
   ],
   "source": [
    "def get_vocab(data):\n",
    "    vocab = ['<PAD>']\n",
    "    for line in tqdm(data):\n",
    "        for char in line:\n",
    "            if char not in vocab:\n",
    "                vocab.append(char)\n",
    "    return vocab\n",
    "\n",
    "pny2id = get_vocab(inputs)\n",
    "han2id = get_vocab(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['<PAD>', 'lv4', 'shi4', 'yang2', 'chun1', 'yan1', 'jing3', 'da4', 'kuai4', 'wen2']\n",
      "['<PAD>', '绿', '是', '阳', '春', '烟', '景', '大', '块', '文']\n"
     ]
    }
   ],
   "source": [
    "print(pny2id[:10])\n",
    "print(han2id[:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- data index\n",
    "\n",
    "将字符symbol格式的文本数据通过字典转化为index形式的数字形式的表示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 5898.00it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3931.01it/s]\n"
     ]
    }
   ],
   "source": [
    "input_num = [[pny2id.index(pny) for pny in line] for line in tqdm(inputs)]\n",
    "label_num = [[han2id.index(han) for han in line] for line in tqdm(labels)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 数据生成器\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  11  16  17\n",
      "   18   2   1  19  20  21  22  23  24  25  26  27   0   0   0]\n",
      " [ 28  29  30  31  32  11  33  34  35  36  37  38  39  40  41  36  42  43\n",
      "   44  45  46  47  48  49  50  51  49  52  11  53   2  54  20]\n",
      " [ 55  56  57  58  59  60  61  62  63  64  15  65  66  67  59  67  68  69\n",
      "   11  70  71  49  72  73  74  75  76  77  59  78  79  42   0]\n",
      " [ 80  81  82   9  83  84  28  49  85  86  75  87  88  89  39  56  90  11\n",
      "   91  92  93  92  94  95  92  96  97  98  32  99 100   0   0]]\n",
      "[[  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  11  16  17\n",
      "   18   2   1  19  20  21  22  23  24  25  26  27   0   0   0]\n",
      " [ 28  29  30  31  32  11  33  34  35  36  37  38  39  40  41  42  43  44\n",
      "   45  46  47  48  49  50  51  52  50  53  11  54  55  56  57]\n",
      " [ 58  59  60  61  62  63  64  65  66  67  68  69  70  71  62  71  72  73\n",
      "   74  75  76  77  78  79  80  81  82  83  62  84  85  86   0]\n",
      " [ 87  88  89  90  91  92  93  50  94  95  81  96  97  98  39  59  99  11\n",
      "  100 101 102 101 103 104 105 106 107 108 109 110 111   0   0]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "def get_batch(input_data, label_data, batch_size):\n",
    "    batch_num = len(input_data) // batch_size\n",
    "    for k in range(batch_num):\n",
    "        begin = k * batch_size\n",
    "        end = begin + batch_size\n",
    "        input_batch = input_data[begin:end]\n",
    "        label_batch = label_data[begin:end]\n",
    "        max_len = max([len(line) for line in input_batch])\n",
    "        input_batch = np.array([line + [0] * (max_len - len(line)) for line in input_batch])\n",
    "        label_batch = np.array([line + [0] * (max_len - len(line)) for line in label_batch])\n",
    "        yield input_batch, label_batch\n",
    "        \n",
    "        \n",
    "batch = get_batch(input_num, label_num, 4)\n",
    "input_batch, label_batch = next(batch)\n",
    "print(input_batch)\n",
    "print(label_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "数据处理部分先到这里，有了词典和数据，就能将符号转化为数值形式的索引号了。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 模型搭建\n",
    "- 模型结构如下：\n",
    "![1545013574%281%29.jpg](attachment:1545013574%281%29.jpg)\n",
    "\n",
    "- 其中CBHG结构如下：\n",
    "![1545013753%281%29.jpg](attachment:1545013753%281%29.jpg)\n",
    "\n",
    "CBHG模块由1-D convolution bank ，highway network ，bidirectional GRU 组成。\n",
    "它的功能是从输入中提取有价值的特征，有利于提高模型的泛化能力。这里直接借用原作者代码，给出简要介绍。\n",
    "\n",
    "### embedding层\n",
    "光有对应的id，没法很好的表征文本信息，这里就涉及到构造词向量，关于词向量不在说明，网上有很多资料，模型中使用词嵌入层，通过训练不断的学习到语料库中的每个字的词向量，代码如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed(inputs, vocab_size, num_units, zero_pad=True, scope=\"embedding\", reuse=None):\n",
    "    with tf.variable_scope(scope, reuse=reuse):\n",
    "        lookup_table = tf.get_variable('lookup_table',\n",
    "                                       dtype=tf.float32,\n",
    "                                       shape=[vocab_size, num_units],\n",
    "                                       initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.01))\n",
    "        if zero_pad:\n",
    "            lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),\n",
    "                                      lookup_table[1:, :]), 0)\n",
    "    return tf.nn.embedding_lookup(lookup_table, inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encoder pre-net module\n",
    "embeding layer之后是一个encoder pre-net模块，它有两个隐藏层，层与层之间的连接均是全连接；\n",
    "第一层的隐藏单元数目与输入单元数目一致，\n",
    "第二层的隐藏单元数目为第一层的一半；两个隐藏层采用的激活函数均为ReLu，并保持0.5的dropout来提高泛化能力 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prenet(inputs, num_units=None, is_training=True, scope=\"prenet\", reuse=None, dropout_rate=0.2):\n",
    "    '''Prenet for Encoder and Decoder1.\n",
    "    Args:\n",
    "      inputs: A 2D or 3D tensor.\n",
    "      num_units: A list of two integers. or None.\n",
    "      is_training: A python boolean.\n",
    "      scope: Optional scope for `variable_scope`.\n",
    "      reuse: Boolean, whether to reuse the weights of a previous layer\n",
    "        by the same name.\n",
    "\n",
    "    Returns:\n",
    "      A 3D tensor of shape [N, T, num_units/2].\n",
    "    '''\n",
    "\n",
    "    with tf.variable_scope(scope, reuse=reuse):\n",
    "        outputs = tf.layers.dense(inputs, units=num_units[0], activation=tf.nn.relu, name=\"dense1\")\n",
    "        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=is_training, name=\"dropout1\")\n",
    "        outputs = tf.layers.dense(outputs, units=num_units[1], activation=tf.nn.relu, name=\"dense2\")\n",
    "        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=is_training, name=\"dropout2\")\n",
    "    return outputs  # (N, ..., num_units[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### conv bank\n",
    "输入序列首先会经过一个卷积层，注意这个卷积层，它有K个大小不同的1维的filter，其中filter的大小为1,2,3…K。\n",
    "这些大小不同的卷积核提取了长度不同的上下文信息。其实就是n-gram语言模型的思想，K的不同对应了不同的gram, \n",
    "例如unigrams, bigrams, up to K-grams，然后，将经过不同大小的k个卷积核的输出堆积在一起\n",
    "（注意：在做卷积时，运用了padding，因此这k个卷积核输出的大小均是相同的），\n",
    "也就是把不同的gram提取到的上下文信息组合在一起，下一层为最大池化层，stride为1，width为2。\n",
    "- 定义一个卷积层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv1d(inputs,\n",
    "       filters=None,\n",
    "       size=1,\n",
    "       rate=1,\n",
    "       padding=\"SAME\",\n",
    "       use_bias=False,\n",
    "       activation_fn=None,\n",
    "       scope=\"conv1d\",\n",
    "       reuse=None):\n",
    "    '''\n",
    "    Args:\n",
    "      inputs: A 3-D tensor with shape of [batch, time, depth].\n",
    "      filters: An int. Number of outputs (=activation maps)\n",
    "      size: An int. Filter size.\n",
    "      rate: An int. Dilation rate.\n",
    "      padding: Either `same` or `valid` or `causal` (case-insensitive).\n",
    "      use_bias: A boolean.\n",
    "      scope: Optional scope for `variable_scope`.\n",
    "      reuse: Boolean, whether to reuse the weights of a previous layer\n",
    "        by the same name.\n",
    "\n",
    "    Returns:\n",
    "      A masked tensor of the same shape and dtypes as `inputs`.\n",
    "    '''    \n",
    "    with tf.variable_scope(scope):\n",
    "        if padding.lower() == \"causal\":\n",
    "            # pre-padding for causality\n",
    "            pad_len = (size - 1) * rate  # padding size\n",
    "            inputs = tf.pad(inputs, [[0, 0], [pad_len, 0], [0, 0]])\n",
    "            padding = \"valid\"\n",
    "\n",
    "        if filters is None:\n",
    "            filters = inputs.get_shape().as_list[-1]\n",
    "\n",
    "        params = {\"inputs\": inputs, \"filters\": filters, \"kernel_size\": size,\n",
    "                  \"dilation_rate\": rate, \"padding\": padding, \"activation\": activation_fn,\n",
    "                  \"use_bias\": use_bias, \"reuse\": reuse}\n",
    "\n",
    "        outputs = tf.layers.conv1d(**params)\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 由不同kernel size的卷积，组合而成的卷积块\n",
    "\n",
    "参数为：\n",
    "\n",
    "1. N: batch size\n",
    "1. T: time steps\n",
    "1. C: embedding hidden units"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv1d_banks(inputs, num_units=None, K=16, is_training=True, scope=\"conv1d_banks\", reuse=None):\n",
    "    '''Applies a series of conv1d separately.\n",
    "\n",
    "    Args:\n",
    "      inputs: A 3d tensor with shape of [N, T, C]\n",
    "      K: An int. The size of conv1d banks. That is,\n",
    "        The `inputs` are convolved with K filters: 1, 2, ..., K.\n",
    "      is_training: A boolean. This is passed to an argument of `batch_normalize`.\n",
    "\n",
    "    Returns:\n",
    "      A 3d tensor with shape of [N, T, K*Hp.embed_size//2].\n",
    "    '''\n",
    "    with tf.variable_scope(scope, reuse=reuse):\n",
    "        outputs = conv1d(inputs, num_units // 2, 1)  # k=1\n",
    "        for k in range(2, K + 1):  # k = 2...K\n",
    "            with tf.variable_scope(\"num_{}\".format(k)):\n",
    "                output = conv1d(inputs, num_units, k)\n",
    "                outputs = tf.concat((outputs, output), -1)\n",
    "        outputs = normalize(outputs, is_training=is_training,\n",
    "                            activation_fn=tf.nn.relu)\n",
    "    return outputs  # (N, T, Hp.embed_size//2*K)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### gru\n",
    "如果换成lstm有可能效果更好。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gru(inputs, num_units=None, bidirection=False, seqlen=None, scope=\"gru\", reuse=None):\n",
    "    '''Applies a GRU.\n",
    "\n",
    "    Args:\n",
    "      inputs: A 3d tensor with shape of [N, T, C].\n",
    "      num_units: An int. The number of hidden units.\n",
    "      bidirection: A boolean. If True, bidirectional results\n",
    "        are concatenated.\n",
    "      scope: Optional scope for `variable_scope`.\n",
    "      reuse: Boolean, whether to reuse the weights of a previous layer\n",
    "        by the same name.\n",
    "\n",
    "    Returns:\n",
    "      If bidirection is True, a 3d tensor with shape of [N, T, 2*num_units],\n",
    "        otherwise [N, T, num_units].\n",
    "    '''\n",
    "    with tf.variable_scope(scope, reuse=reuse):\n",
    "        if num_units is None:\n",
    "            num_units = inputs.get_shape().as_list[-1]\n",
    "\n",
    "        cell = tf.contrib.rnn.GRUCell(num_units)\n",
    "        if bidirection:\n",
    "            cell_bw = tf.contrib.rnn.GRUCell(num_units)\n",
    "            outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell, cell_bw, inputs,\n",
    "                                                         sequence_length=seqlen,\n",
    "                                                         dtype=tf.float32)\n",
    "            return tf.concat(outputs, 2)\n",
    "        else:\n",
    "            outputs, _ = tf.nn.dynamic_rnn(cell, inputs,\n",
    "                                           sequence_length=seqlen,\n",
    "                                           dtype=tf.float32)\n",
    "\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### highwaynet\n",
    "下一层输入到highway layers，highway nets的每一层结构为：把输入同时放入到两个一层的全连接网络中，\n",
    "这两个网络的激活函数分别采用了ReLu和sigmoid函数，假定输入为input，ReLu的输出为output1，sigmoid的输出为output2，\n",
    "那么highway layer的输出为output=output1∗output2+input∗（1−output2)。论文中使用了4层highway layer。\n",
    "代码如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def highwaynet(inputs, num_units=None, scope=\"highwaynet\", reuse=None):\n",
    "    '''Highway networks, see https://arxiv.org/abs/1505.00387\n",
    "    Args:\n",
    "      inputs: A 3D tensor of shape [N, T, W].\n",
    "      num_units: An int or `None`. Specifies the number of units in the highway layer\n",
    "             or uses the input size if `None`.\n",
    "      scope: Optional scope for `variable_scope`.\n",
    "      reuse: Boolean, whether to reuse the weights of a previous layer\n",
    "        by the same name.\n",
    "    Returns:\n",
    "      A 3D tensor of shape [N, T, W].\n",
    "    '''\n",
    "    if not num_units:\n",
    "        num_units = inputs.get_shape()[-1]\n",
    "\n",
    "    with tf.variable_scope(scope, reuse=reuse):\n",
    "        H = tf.layers.dense(inputs, units=num_units, activation=tf.nn.relu, name=\"dense1\")\n",
    "        T = tf.layers.dense(inputs, units=num_units, activation=tf.nn.sigmoid,\n",
    "                            bias_initializer=tf.constant_initializer(-1.0), name=\"dense2\")\n",
    "        C = 1. - T\n",
    "        outputs = H * T + inputs * C\n",
    "\n",
    "\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### batch normalize\n",
    "\n",
    "使用bn层，加速训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize(inputs,\n",
    "              decay=.99,\n",
    "              epsilon=1e-8,\n",
    "              is_training=True,\n",
    "              activation_fn=None,\n",
    "              reuse=None,\n",
    "              scope=\"normalize\"):\n",
    "    '''Applies {batch|layer} normalization.\n",
    "\n",
    "    Args:\n",
    "      inputs: A tensor with 2 or more dimensions, where the first dimension has\n",
    "        `batch_size`. If type is `bn`, the normalization is over all but\n",
    "        the last dimension. Or if type is `ln`, the normalization is over\n",
    "        the last dimension. Note that this is different from the native\n",
    "        `tf.contrib.layers.batch_norm`. For this I recommend you change\n",
    "        a line in ``tensorflow/contrib/layers/python/layers/layer.py`\n",
    "        as follows.\n",
    "        Before: mean, variance = nn.moments(inputs, axis, keep_dims=True)\n",
    "        After: mean, variance = nn.moments(inputs, [-1], keep_dims=True)\n",
    "      type: A string. Either \"bn\" or \"ln\".\n",
    "      decay: Decay for the moving average. Reasonable values for `decay` are close\n",
    "        to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc.\n",
    "        Lower `decay` value (recommend trying `decay`=0.9) if model experiences\n",
    "        reasonably good training performance but poor validation and/or test\n",
    "        performance.\n",
    "      is_training: Whether or not the layer is in training mode. W\n",
    "      activation_fn: Activation function.\n",
    "      scope: Optional scope for `variable_scope`.\n",
    "\n",
    "    Returns:\n",
    "      A tensor with the same shape and data dtype as `inputs`.\n",
    "    '''\n",
    "    inputs_shape = inputs.get_shape()\n",
    "    inputs_rank = inputs_shape.ndims\n",
    "\n",
    "    # use fused batch norm if inputs_rank in [2, 3, 4] as it is much faster.\n",
    "    # pay attention to the fact that fused_batch_norm requires shape to be rank 4 of NHWC.\n",
    "    inputs = tf.expand_dims(inputs, axis=1)\n",
    "    outputs = tf.contrib.layers.batch_norm(inputs=inputs,\n",
    "                                            decay=decay,\n",
    "                                            center=True,\n",
    "                                            scale=True,\n",
    "                                            updates_collections=None,\n",
    "                                            is_training=is_training,\n",
    "                                            scope=scope,\n",
    "                                            zero_debias_moving_mean=True,\n",
    "                                            fused=True,\n",
    "                                            reuse=reuse)\n",
    "    outputs = tf.squeeze(outputs, axis=1)\n",
    "\n",
    "    if activation_fn:\n",
    "        outputs = activation_fn(outputs)\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 搭建模型\n",
    "由各个组件构成模型。\n",
    "- 模型结构如下：\n",
    "![1545013574%281%29.jpg](attachment:1545013574%281%29.jpg)\n",
    "\n",
    "- 其中CBHG结构如下：\n",
    "![1545013753%281%29.jpg](attachment:1545013753%281%29.jpg)\n",
    "\n",
    "- 模型代码如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Graph():\n",
    "    '''Builds a model graph'''\n",
    "\n",
    "    def __init__(self, arg):\n",
    "        tf.reset_default_graph()\n",
    "        self.pny_size = arg.pny_size\n",
    "        self.han_size = arg.han_size\n",
    "        self.embed_size = arg.embed_size\n",
    "        self.is_training = arg.is_training\n",
    "        self.num_highwaynet_blocks = arg.num_highwaynet_blocks\n",
    "        self.encoder_num_banks = arg.encoder_num_banks\n",
    "        self.lr = arg.lr\n",
    "        \n",
    "        self.x = tf.placeholder(tf.int32, shape=(None, None))\n",
    "        self.y = tf.placeholder(tf.int32, shape=(None, None))\n",
    "        \n",
    "        # Character Embedding for x\n",
    "        enc = embed(self.x, self.pny_size, self.embed_size, scope=\"emb_x\")\n",
    "        # Encoder pre-net\n",
    "        prenet_out = prenet(enc,\n",
    "                            num_units=[self.embed_size, self.embed_size // 2],\n",
    "                            is_training=self.is_training)  # (N, T, E/2)\n",
    "\n",
    "        # Encoder CBHG\n",
    "        ## Conv1D bank\n",
    "        enc = conv1d_banks(prenet_out,\n",
    "                            K=self.encoder_num_banks,\n",
    "                            num_units=self.embed_size // 2,\n",
    "                            is_training=self.is_training)  # (N, T, K * E / 2)\n",
    "\n",
    "        ## Max pooling\n",
    "        enc = tf.layers.max_pooling1d(enc, 2, 1, padding=\"same\")  # (N, T, K * E / 2)\n",
    "\n",
    "        ## Conv1D projections\n",
    "        enc = conv1d(enc, self.embed_size // 2, 5, scope=\"conv1d_1\")  # (N, T, E/2)\n",
    "        enc = normalize(enc, is_training=self.is_training,\n",
    "                            activation_fn=tf.nn.relu, scope=\"norm1\")\n",
    "        enc = conv1d(enc, self.embed_size // 2, 5, scope=\"conv1d_2\")  # (N, T, E/2)\n",
    "        enc = normalize(enc, is_training=self.is_training,\n",
    "                            activation_fn=None, scope=\"norm2\")\n",
    "        enc += prenet_out  # (N, T, E/2) # residual connections\n",
    "\n",
    "        ## Highway Nets\n",
    "        for i in range(self.num_highwaynet_blocks):\n",
    "            enc = highwaynet(enc, num_units=self.embed_size // 2,\n",
    "                                scope='highwaynet_{}'.format(i))  # (N, T, E/2)\n",
    "\n",
    "        ## Bidirectional GRU\n",
    "        enc = gru(enc, self.embed_size // 2, True, scope=\"gru1\")  # (N, T, E)\n",
    "\n",
    "        ## Readout\n",
    "        self.outputs = tf.layers.dense(enc, self.han_size, use_bias=False)\n",
    "        self.preds = tf.to_int32(tf.argmax(self.outputs, axis=-1))\n",
    "\n",
    "        if self.is_training:\n",
    "            self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.y, logits=self.outputs)\n",
    "            self.istarget = tf.to_float(tf.not_equal(self.y, tf.zeros_like(self.y)))  # masking\n",
    "            self.hits = tf.to_float(tf.equal(self.preds, self.y)) * self.istarget\n",
    "            self.acc = tf.reduce_sum(self.hits) / tf.reduce_sum(self.istarget)\n",
    "            self.mean_loss = tf.reduce_sum(self.loss * self.istarget) / tf.reduce_sum(self.istarget)\n",
    "\n",
    "            # Training Scheme\n",
    "            self.global_step = tf.Variable(0, name='global_step', trainable=False)\n",
    "            self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)\n",
    "            self.train_op = self.optimizer.minimize(self.mean_loss, global_step=self.global_step)\n",
    "\n",
    "            # Summary\n",
    "            tf.summary.scalar('mean_loss', self.mean_loss)\n",
    "            tf.summary.scalar('acc', self.acc)\n",
    "            self.merged = tf.summary.merge_all()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 参数设定"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_hparams():\n",
    "    params = tf.contrib.training.HParams(\n",
    "        \n",
    "        # vocab\n",
    "        pny_size = 50,\n",
    "        han_size = 50,\n",
    "        # embedding size\n",
    "        embed_size = 300,\n",
    "        num_highwaynet_blocks = 4,\n",
    "        encoder_num_banks = 8,\n",
    "        lr = 0.001,\n",
    "        is_training = True)\n",
    "    return params\n",
    "\n",
    "arg = create_hparams()\n",
    "arg.pny_size = len(pny2id)\n",
    "arg.han_size = len(han2id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From d:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\util\\deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "`NHWC` for data_format is deprecated, use `NWC` instead\n",
      "INFO:tensorflow:Restoring parameters from logs/model\n",
      "epochs 5 : average loss =  1.2706838726997376\n",
      "epochs 10 : average loss =  0.7730961179733277\n",
      "epochs 15 : average loss =  0.4291091418266296\n",
      "epochs 20 : average loss =  0.2579580122232437\n",
      "epochs 25 : average loss =  0.17582486391067506\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "epochs = 25\n",
    "batch_size = 4\n",
    "\n",
    "g = Graph(arg)\n",
    "\n",
    "saver =tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    merged = tf.summary.merge_all()\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    if os.path.exists('logs/model.meta'):\n",
    "        saver.restore(sess, 'logs/model')\n",
    "    writer = tf.summary.FileWriter('tensorboard/lm', tf.get_default_graph())\n",
    "    for k in range(epochs):\n",
    "        total_loss = 0\n",
    "        batch_num = len(input_num) // batch_size\n",
    "        batch = get_batch(input_num, label_num, batch_size)\n",
    "        for i in range(batch_num):\n",
    "            input_batch, label_batch = next(batch)\n",
    "            feed = {g.x: input_batch, g.y: label_batch}\n",
    "            cost,_ = sess.run([g.mean_loss,g.train_op], feed_dict=feed)\n",
    "            total_loss += cost\n",
    "            if (k * batch_num + i) % 10 == 0:\n",
    "                rs=sess.run(merged, feed_dict=feed)\n",
    "                writer.add_summary(rs, k * batch_num + i)\n",
    "        if (k+1) % 5 == 0:\n",
    "            print('epochs', k+1, ': average loss = ', total_loss/batch_num)\n",
    "    saver.save(sess, 'logs/model')\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型推断"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from logs/model\n",
      "输入测试拼音: lv4 shi4 yang2 chun1 yan1 jing3 da4 kuai4 wen2 zhang1 de di3 se4 si4 yue4 de lin2 luan2 geng4 shi4 lv4 de2 xian1 huo2 xiu4 mei4 shi1 yi4 ang4 ran2\n",
      "绿是阳春烟景大块文章的底色四月的林峦更是绿得鲜秀秀秀然回债权\n",
      "输入测试拼音: exit\n"
     ]
    }
   ],
   "source": [
    "arg.is_training = False\n",
    "\n",
    "g = Graph(arg)\n",
    "\n",
    "saver =tf.train.Saver()\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    saver.restore(sess, 'logs/model')\n",
    "    while True:\n",
    "        line = input('输入测试拼音: ')\n",
    "        if line == 'exit': break\n",
    "        line = line.strip('\\n').split(' ')\n",
    "        x = np.array([pny2id.index(pny) for pny in line])\n",
    "        x = x.reshape(1, -1)\n",
    "        preds = sess.run(g.preds, {g.x: x})\n",
    "        got = ''.join(han2id[idx] for idx in preds[0])\n",
    "        print(got)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
